import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib
from load_aokvqa import load_aokvqa, get_coco_path
import ollama
from collections import Counter

# Fix the RNG seed at import time so that the random.sample() call that picks
# the few-shot prompt examples in main() selects the same examples on every run.
random.seed(0)


def get_qwen_result(image_path, prompt, args):
    """Send one single-turn prompt, optionally with an image, to Ollama.

    NOTE: despite the function name, the request is sent to the
    ``llama3.2-vision:11b`` model.

    Args:
        image_path: Path to an image file. The image is attached to the
            message only if this path exists on disk; otherwise the request
            is sent as text only.
        prompt: Text used as the content of the single user message.
        args: Namespace providing ``temperature``, ``max_tokens``, ``top_p``,
            ``frequency_penalty`` and ``presence_penalty``.

    Returns:
        The text content of the model's reply message.
    """
    message = {
        "role": "user",
        "content": prompt,
    }

    # Only attach the image when the file is actually present; a missing
    # image is not treated as an error here.
    if os.path.exists(image_path):
        message["images"] = [image_path]

    response = ollama.chat(
        model="llama3.2-vision:11b",
        stream=False,
        messages=[message],
        options={
            "temperature": args.temperature,
            # Ollama names its generation-length cap ``num_predict``;
            # ``max_tokens`` is not an Ollama option, so the limit was
            # previously not applied.
            "num_predict": args.max_tokens,
            "top_p": args.top_p,
            "frequency_penalty": args.frequency_penalty,
            "presence_penalty": args.presence_penalty,
            # A newline is registered as a stop sequence.
            "stop": ["\n"],
        },
    )

    return response['message']['content']

def prompt_element(d, context=None, include_choices=False, answer=False):
    """Render one question record as a block of prompt text.

    Args:
        d: Mapping with a ``'question'`` entry. ``'choices'`` is read when
            ``include_choices`` or ``answer`` is set; ``'correct_choice_idx'``
            and ``'rationales'`` are read when ``answer`` is set.
        context: Optional text emitted on a leading ``Context:`` line. The
            line is omitted only when this is ``None``.
        include_choices: Whether to add an ``Options:`` line listing the
            choices separated by ``', '``.
        answer: Whether to append the worked answer after ``A:``.

    Returns:
        The assembled prompt fragment; it always ends with ``A:`` followed by
        the answer text if one was requested.
    """
    pieces = []

    if context is not None:
        pieces.append(f"Context: {context}\n")

    pieces.append(f"Q: {d['question']}\n")

    if include_choices:
        options = ', '.join(d['choices'])
        pieces.append(f"Options: {options}.\n")

    pieces.append("A:")

    if answer:
        correct = d['choices'][d['correct_choice_idx']]
        # The rationales value is interpolated as-is (its str() form).
        pieces.append(f"The correct answer is {correct},because {d['rationales']}")

    return ''.join(pieces)

def main():
    """Run few-shot multiple-choice prompting over an A-OKVQA split.

    Parses the command line, samples ``--n`` few-shot examples from the train
    split, builds one prompt per evaluation question, queries the model
    several times per question, and writes the most frequent reply for each
    question id to ``--out`` as a JSON object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--coco-dir', type=str, default="/home/test/yxl/MCoT/data/COCO", dest='coco_dir',
                        help='Root directory of the COCO images.')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--n', type=int, default=0, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    args = parser.parse_args()

    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Load each context file independently. Previously both were loaded only
    # when --context was given, which crashed if --train-context was absent
    # and ignored --train-context when it was given on its own.
    train_context = {}
    if args.train_context_file is not None:
        train_context = json.load(args.train_context_file)
    context = {}
    if args.context_file is not None:
        context = json.load(args.context_file)

    prompt_examples = random.sample(train_set, args.num_examples)

    # The few-shot part of the prompt is identical for every evaluation item,
    # so build it once. Each example's context is looked up under the
    # example's own question id (not the id of the item being evaluated).
    few_shot = ''.join(
        prompt_element(e,
                       context=train_context.get(e['question_id'], None),
                       include_choices=args.include_choices,
                       answer=True) + '\n\n'
        for e in prompt_examples
    )
    prompt_head = args.prompt_prefix + few_shot + "Let's think step by step\n"

    # Number of samples drawn per question for the majority vote.
    num_votes = 5

    predictions = {}

    for d in tqdm(eval_set):
        q = d['question_id']

        # The evaluated question always lists its options, regardless of
        # --include-choices (which only affects the few-shot examples).
        prompt = prompt_head + prompt_element(d,
                                              context=context.get(q, None),
                                              include_choices=True,
                                              answer=False)

        # Resolve the image under the split being evaluated rather than
        # always under 'val'.
        image_path = get_coco_path(args.split, d['image_id'], args.coco_dir)

        # Query the model repeatedly and keep the most frequent raw reply;
        # on a tie the reply seen first wins.
        votes = Counter(get_qwen_result(image_path, prompt, args) for _ in range(num_votes))
        predictions[q] = votes.most_common(1)[0][0]

    json.dump(predictions, args.output_file)



# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
